{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "animeGAN.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/ptran1203/pytorch-animeGAN/blob/master/notebooks/animeGAN.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rvh6-cYkHY90"
      },
      "source": [
        "import os\n",
        "\n",
        "import torch\n",
        "from google.colab import drive, output\n",
        "\n",
        "# Fail early with an actionable message instead of a low-level CUDA error.\n",
        "if not torch.cuda.is_available():\n",
        "    raise RuntimeError(\n",
        "        'No CUDA device available. In Colab choose '\n",
        "        'Runtime > Change runtime type > GPU, then re-run this cell.'\n",
        "    )\n",
        "dv = torch.cuda.get_device_name(0)\n",
        "print(dv)\n",
        "\n",
        "drive.mount('/content/drive', force_remount=False)\n",
        "repo = \"Pytorch-animeGAN\"\n",
        "%cd \"/content\"\n",
        "# Remove any previous checkout so the clone starts clean.\n",
        "!rm -rf {repo}\n",
        "!git clone https://github.com/ptran1203/{repo}\n",
        "# output.clear() below wipes git's messages, so detect a failed clone here.\n",
        "if not os.path.isdir(repo):\n",
        "    raise RuntimeError(f'git clone did not create {repo!r}')\n",
        "%cd {repo}\n",
        "output.clear()"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x0_9-Gd3AZKk"
      },
      "source": [
        "import os\n",
        "import urllib\n",
        "\n",
        "# Release asset to fetch, and the temporary local name of the archive.\n",
        "dataset_url = 'https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/dataset_v1.zip'\n",
        "data_path = 'anime-gan.zip'\n",
        "\n",
        "# Download and unpack only when /content/dataset is not already present.\n",
        "if not os.path.exists('/content/dataset'):\n",
        "    !wget -O {data_path} {dataset_url}\n",
        "    !unzip {data_path} -d /content\n",
        "    !rm {data_path}\n",
        "\n",
        "    # Still missing after unzip: stop here so the logs above are not cleared.\n",
        "    if not os.path.exists('/content/dataset'):\n",
        "        raise ValueError(f\"Download Failed, {data_path}\")\n",
        "\n",
        "output.clear()"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "62WdajGhbB2h"
      },
      "source": [
        "# Output locations on the Google Drive mounted at /content/drive.\n",
        "working_dir = '/content/drive/MyDrive/animeGAN'\n",
        "ckp_dir = working_dir + '/checkpoints'\n",
        "save_img_dir = working_dir + '/generated_samples'\n",
        "\n",
        "# dv holds the GPU name queried in the setup cell.\n",
        "print(f\"You're running on {dv}\")"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cn88CEIiHWE3"
      },
      "source": [
        "# Launch training. Every line but the last ends with a continuation backslash;\n",
        "# the final argument must not, or a stray '\\' is handed to the shell.\n",
        "!python3 train.py --dataset 'Shinkai'\\\n",
        "                  --batch 6\\\n",
        "                  --debug-samples 0\\\n",
        "                  --init-epochs 10\\\n",
        "                  --checkpoint-dir {ckp_dir}\\\n",
        "                  --save-image-dir {save_img_dir}\\\n",
        "                  --save-interval 1\\\n",
        "                  --gan-loss lsgan\\\n",
        "                  --init-lr 0.0001\\\n",
        "                  --lr-g 0.00002\\\n",
        "                  --lr-d 0.00004\\\n",
        "                  --wadvd 10.0\\\n",
        "                  --wadvg 10.0\\\n",
        "                  --wcon 1.5\\\n",
        "                  --wgra 3.0\\\n",
        "                  --wcol 70.0\\\n",
        "                  --resume GD\\\n",
        "                  --use_sn"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "f3BSs5dkAFXP",
        "outputId": "51bf178f-d78f-4702-9c2a-125ce25621bb",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "from inference import Transformer\n",
        "\n",
        "# Build the transformer from the Shinkai generator weights saved in ckp_dir.\n",
        "transformer = Transformer(f'{ckp_dir}/generator_Shinkai.pth')"
      ],
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Weight loaded, ready to predict\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0hTIi5dvkzAZ"
      },
      "source": [
        "import os\n",
        "import random\n",
        "\n",
        "import cv2\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def random_img(img_dir):\n",
        "    \"\"\"Load one randomly chosen file from `img_dir` with OpenCV.\n",
        "\n",
        "    Args:\n",
        "        img_dir: Directory to pick a file from.\n",
        "\n",
        "    Returns:\n",
        "        The array from `cv2.imread` with its last (channel) axis reversed,\n",
        "        i.e. BGR -> RGB for a colour image.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If `img_dir` is empty or the chosen file cannot be decoded.\n",
        "    \"\"\"\n",
        "    names = os.listdir(img_dir)\n",
        "    if not names:\n",
        "        raise ValueError(f'No files found in {img_dir}')\n",
        "    p = os.path.join(img_dir, random.choice(names))\n",
        "    img = cv2.imread(p)\n",
        "    # cv2.imread signals failure by returning None rather than raising.\n",
        "    if img is None:\n",
        "        raise ValueError(f'Could not read image: {p}')\n",
        "    return img[:, :, ::-1]\n",
        "\n",
        "\n",
        "image = random_img('/content/dataset/test/HR_photo')\n",
        "# cv2.resize takes the target size as (width, height).\n",
        "image = cv2.resize(image, (768, 512))\n",
        "\n",
        "# Map the output through (x + 1) / 2 for display; presumably it is in [-1, 1] -- confirm in inference.py.\n",
        "anime_img = (transformer.transform(image) + 1) / 2\n",
        "\n",
        "# Side-by-side comparison: source photo on the left, generated image on the right.\n",
        "fig, (ax_photo, ax_anime) = plt.subplots(1, 2, figsize=(18, 25))\n",
        "ax_photo.imshow(image)\n",
        "ax_photo.set_title('Input photo')\n",
        "# The result is indexed with [0]; presumably a batch holding a single image -- confirm.\n",
        "ax_anime.imshow(anime_img[0])\n",
        "ax_anime.set_title('Generated')\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l1STqcCd4VcQ"
      },
      "source": [
        "## Image inference"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nxiL259r4Uqi",
        "outputId": "38181834-882f-42c4-c5ac-3674d91e2c84",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "# !python3 inference_image.py --checkpoint {ckp_dir}/generator_Shinkai.pth\\\n",
        "#                             --src /content/dataset/test/HR_photo\\\n",
        "#                             --dest {working_dir}/inference_shinkai"
      ],
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Weight loaded, ready to predict\n",
            "Found 45 images in /content/dataset/test/HR_photo\n",
            "100% 45/45 [00:24<00:00,  1.82it/s]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vb_Jxx7gO-vp"
      },
      "source": [
        "## Video inference"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0WFuk0cfPAzl"
      },
      "source": [
        "# !python3 inference_video.py --checkpoint {ckp_dir}\\\n",
        "#                             --src /content/test_vid_3.mp4\\\n",
        "#                             --dest /content/test_vid_3_anime.mp4\\\n",
        "#                             --batch-size 2"
      ],
      "execution_count": 8,
      "outputs": []
    }
  ]
}